{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "91c6a7ef",
   "metadata": {},
   "source": [
    "# Dynamodb Chat Message History\n",
    "\n",
    "This notebook goes over how to use Dynamodb to store chat message history."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f608be0",
   "metadata": {},
   "source": [
    "First make sure you have correctly configured the [AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html). Then make sure you have installed boto3."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "030d784f",
   "metadata": {},
   "source": [
    "Next, create the DynamoDB Table where we will be storing messages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "93ce1811",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "import boto3\n",
    "\n",
    "# Get the service resource.\n",
    "dynamodb = boto3.resource(\"dynamodb\")\n",
    "\n",
    "# Create the DynamoDB table.\n",
    "table = dynamodb.create_table(\n",
    "    TableName=\"SessionTable\",\n",
    "    KeySchema=[{\"AttributeName\": \"SessionId\", \"KeyType\": \"HASH\"}],\n",
    "    AttributeDefinitions=[{\"AttributeName\": \"SessionId\", \"AttributeType\": \"S\"}],\n",
    "    BillingMode=\"PAY_PER_REQUEST\",\n",
    ")\n",
    "\n",
    "# Wait until the table exists.\n",
    "table.meta.client.get_waiter(\"table_exists\").wait(TableName=\"SessionTable\")\n",
    "\n",
    "# Print out some data about the table.\n",
    "print(table.item_count)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a9b310b",
   "metadata": {},
   "source": [
    "## DynamoDBChatMessageHistory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d15e3302",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory\n",
    "\n",
    "history = DynamoDBChatMessageHistory(table_name=\"SessionTable\", session_id=\"0\")\n",
    "\n",
    "history.add_user_message(\"hi!\")\n",
    "\n",
    "history.add_ai_message(\"whats up?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "64fc465e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "[HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False),\n HumanMessage(content='hi!', additional_kwargs={}, example=False),\n AIMessage(content='whats up?', additional_kwargs={}, example=False)]"
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history.messages"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "955f1b15",
   "metadata": {},
   "source": [
    "## DynamoDBChatMessageHistory with Custom Endpoint URL\n",
    "\n",
    "Sometimes it is useful to specify the URL to the AWS endpoint to connect to. For instance, when you are running locally against [Localstack](https://localstack.cloud/). For those cases you can specify the URL via the `endpoint_url` parameter in the constructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "225713c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory\n",
    "\n",
    "history = DynamoDBChatMessageHistory(\n",
    "    table_name=\"SessionTable\",\n",
    "    session_id=\"0\",\n",
    "    endpoint_url=\"http://localhost.localstack.cloud:4566\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## DynamoDBChatMessageHistory With Different Keys Composite Keys\n",
    "The default key for DynamoDBChatMessageHistory is ```{\"SessionId\": self.session_id}```, but you can modify this to match your table design.\n",
    "\n",
    "### Primary Key Name\n",
    "You may modify the primary key by passing in a primary_key_name value in the constructor, resulting in the following:\n",
    "```{self.primary_key_name: self.session_id}```\n",
    "\n",
    "### Composite Keys\n",
    "When using an existing DynamoDB table, you may need to modify the key structure from the default of to something including a Sort Key. To do this you may use the ```key``` parameter.\n",
    "\n",
    "Passing a value for key will override the primary_key parameter, and the resulting key structure will be the passed value.\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "data": {
      "text/plain": "[HumanMessage(content='hello, composite dynamodb table!', additional_kwargs={}, example=False)]"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory\n",
    "\n",
    "composite_table = dynamodb.create_table(\n",
    "    TableName=\"CompositeTable\",\n",
    "    KeySchema=[{\"AttributeName\": \"PK\", \"KeyType\": \"HASH\"}, {\"AttributeName\": \"SK\", \"KeyType\": \"RANGE\"}],\n",
    "    AttributeDefinitions=[{\"AttributeName\": \"PK\", \"AttributeType\": \"S\"}, {\"AttributeName\": \"SK\", \"AttributeType\": \"S\"}],\n",
    "    BillingMode=\"PAY_PER_REQUEST\",\n",
    ")\n",
    "\n",
    "# Wait until the table exists.\n",
    "composite_table.meta.client.get_waiter(\"table_exists\").wait(TableName=\"CompositeTable\")\n",
    "\n",
    "# Print out some data about the table.\n",
    "print(composite_table.item_count)\n",
    "\n",
    "my_key = {\n",
    "    \"PK\": \"session_id::0\",\n",
    "    \"SK\":  \"langchain_history\",\n",
    "}\n",
    "\n",
    "composite_key_history = DynamoDBChatMessageHistory(\n",
    "    table_name=\"CompositeTable\",\n",
    "    session_id=\"0\",\n",
    "    endpoint_url=\"http://localhost.localstack.cloud:4566\",\n",
    "    key=my_key,\n",
    ")\n",
    "\n",
    "composite_key_history.add_user_message(\"hello, composite dynamodb table!\")\n",
    "\n",
    "composite_key_history.messages"
   ],
   "metadata": {
    "collapsed": false
   }
  },
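  {
   "cell_type": "markdown",
   "id": "c4e8a1f7",
   "metadata": {},
   "source": [
    "The key rules described in this section can be sketched in plain Python. This is only an illustration under stated assumptions: `resolve_key` is a hypothetical helper, not part of LangChain, and the library's own implementation may differ.\n",
    "\n",
    "```python\n",
    "from typing import Dict, Optional\n",
    "\n",
    "\n",
    "def resolve_key(\n",
    "    session_id: str,\n",
    "    primary_key_name: str = \"SessionId\",\n",
    "    key: Optional[Dict[str, str]] = None,\n",
    ") -> Dict[str, str]:\n",
    "    \"\"\"Hypothetical helper mirroring the key rules described in this notebook.\"\"\"\n",
    "    # An explicit key overrides the primary key name entirely.\n",
    "    if key is not None:\n",
    "        return dict(key)\n",
    "    # Otherwise the session id is stored under the primary key name.\n",
    "    return {primary_key_name: session_id}\n",
    "\n",
    "\n",
    "default_key = resolve_key(\"0\")\n",
    "renamed_key = resolve_key(\"0\", primary_key_name=\"PK\")\n",
    "composite_key = resolve_key(\n",
    "    \"0\", key={\"PK\": \"session_id::0\", \"SK\": \"langchain_history\"}\n",
    ")\n",
    "\n",
    "print(default_key)\n",
    "print(renamed_key)\n",
    "print(composite_key)\n",
    "```\n",
    "\n",
    "Only the last call produces a key that fits a table with both a partition key and a sort key, such as `CompositeTable` above."
   ]
  },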
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3b33c988",
   "metadata": {},
   "source": [
    "## Agent with DynamoDB Memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f92d9499",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import Tool\n",
    "from langchain.memory import ConversationBufferMemory\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.agents import initialize_agent\n",
    "from langchain.agents import AgentType\n",
    "from langchain.utilities import PythonREPL\n",
    "from getpass import getpass\n",
    "\n",
    "message_history = DynamoDBChatMessageHistory(table_name=\"SessionTable\", session_id=\"1\")\n",
    "memory = ConversationBufferMemory(\n",
    "    memory_key=\"chat_history\", chat_memory=message_history, return_messages=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1167eeba",
   "metadata": {},
   "outputs": [],
   "source": [
    "python_repl = PythonREPL()\n",
    "\n",
    "# You can create the tool to pass to an agent\n",
    "tools = [\n",
    "    Tool(\n",
    "        name=\"python_repl\",\n",
    "        description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.\",\n",
    "        func=python_repl.run,\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "fce085c5",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValidationError",
     "evalue": "1 validation error for ChatOpenAI\n__root__\n  Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass  `openai_api_key` as a named parameter. (type=value_error)",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mValidationError\u001B[0m                           Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[17], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m llm \u001B[38;5;241m=\u001B[39m \u001B[43mChatOpenAI\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtemperature\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m      2\u001B[0m agent_chain \u001B[38;5;241m=\u001B[39m initialize_agent(\n\u001B[1;32m      3\u001B[0m     tools,\n\u001B[1;32m      4\u001B[0m     llm,\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m      7\u001B[0m     memory\u001B[38;5;241m=\u001B[39mmemory,\n\u001B[1;32m      8\u001B[0m )\n",
      "File \u001B[0;32m~/Documents/projects/langchain/libs/langchain/langchain/load/serializable.py:74\u001B[0m, in \u001B[0;36mSerializable.__init__\u001B[0;34m(self, **kwargs)\u001B[0m\n\u001B[1;32m     73\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__init__\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs: Any) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m---> 74\u001B[0m     \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__init__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     75\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_lc_kwargs \u001B[38;5;241m=\u001B[39m kwargs\n",
      "File \u001B[0;32m~/Documents/projects/langchain/.venv/lib/python3.9/site-packages/pydantic/main.py:341\u001B[0m, in \u001B[0;36mpydantic.main.BaseModel.__init__\u001B[0;34m()\u001B[0m\n",
      "\u001B[0;31mValidationError\u001B[0m: 1 validation error for ChatOpenAI\n__root__\n  Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass  `openai_api_key` as a named parameter. (type=value_error)"
     ]
    }
   ],
   "source": [
    "llm = ChatOpenAI(temperature=0)\n",
    "agent_chain = initialize_agent(\n",
    "    tools,\n",
    "    llm,\n",
    "    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,\n",
    "    verbose=True,\n",
    "    memory=memory,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "952a3103",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_chain.run(input=\"Hello!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54c4aaf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_chain.run(input=\"Who owns Twitter?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9013118",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_chain.run(input=\"My name is Bob.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "405e5315",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_chain.run(input=\"Who am I?\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
